package com.uxsino.commons.tftp;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.util.Timer;
import java.util.TimerTask;
import java.util.function.Consumer;
import java.util.function.Function;

import org.apache.commons.net.io.ToNetASCIIInputStream;
import org.apache.commons.net.io.ToNetASCIIOutputStream;
import org.apache.commons.net.tftp.TFTP;
import org.apache.commons.net.tftp.TFTPAckPacket;
import org.apache.commons.net.tftp.TFTPDataPacket;
import org.apache.commons.net.tftp.TFTPErrorPacket;
import org.apache.commons.net.tftp.TFTPPacket;
import org.apache.commons.net.tftp.TFTPPacketException;
import org.apache.commons.net.tftp.TFTPReadRequestPacket;
import org.apache.commons.net.tftp.TFTPWriteRequestPacket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.uxsino.commons.tftp.TFTPServer.ServerMode;

/**
 * 
 * 描述：TFTP 会话
 * @author <a href="mailto:royrxc@gmail.com">Ran</a>
 * 保存文件传递的基本状态标志信息
 * @date 2018年3月7日
 */

public class TFTPSession {
    private static final Logger LOGGER = LoggerFactory.getLogger(TFTPSession.class);

    /**
     * 标识服务器的读写方式
     */
    private ServerMode mode = ServerMode.GET_AND_PUT;
    
    /**
     * 读取数据的缓存大小
     */
    private int buf_size = TFTPServer.BUF_SIZE;

    /**
     * 多次发送数据，容错处理
     */
    private int times = 1;

    /**
     * 异步发送器
     */
    private Timer timer = new Timer(true);

    /**
     * 消息发送接口
     */
    private Consumer<TFTPPacket> sender;
    
    /**
     * 写文件接口，根据 传递的文件key，返回输出流通道。
     */
    private Function<String, OutputStream> writerFn;

    /**
     * 读文件接口， 根据传递的key，返回文件输入流通道
     */
    private Function<String, InputStream> readerFn;

    /**
     * 发送文件的最大长度，发送时有效
     */
    private int maxFileLength = 0;

    /**
     * 根据readerFn获取的输入流
     */
    private InputStream reader;

    /**
     * 更具writerFn 获取的输出流
     */
    private OutputStream writer;
    
    /**
     * 该session 是否已经处理完成
     */
    private boolean end = false;

    /**
     * 发送或者接受文件的数据包次序
     */
    private int block;
    
    public boolean isEnd(){
        return this.end;
    }
    
    public TFTPSession writerFn(Function<String, OutputStream> writerFn){
        this.writerFn = writerFn;
        return this;
    }
    
    public TFTPSession times(int times){
        this.times = times;
        return this;
    }
    
    public TFTPSession readerFn(Function<String, InputStream> readerFn){
        this.readerFn = readerFn;
        return this;
    }
    
    public TFTPSession sender(Consumer<TFTPPacket> sender){
        this.sender = sender;
        return this;
    }
    
    private TFTPSession(){
        LOGGER.info("TFTP session new!");
    }
    
    public static TFTPSession of(){
        return new TFTPSession();
    }
    
    public synchronized void destory(){
        if(this.reader != null){
            try {
                this.reader.close();
            } catch (IOException e) {
                LOGGER.error("Close input stream error: ", e);
            }
        }
        if(this.writer != null){
            try {
                this.writer.flush();
                this.writer.close();
            } catch (IOException e) {
                LOGGER.error("Close output stream error: ", e);
            }
        }
    }

    public synchronized void end(){
        this.end = true;
        this.destory();
    }

    public TFTPSession mode(ServerMode mode){
        this.mode = mode;
        return this;
    }
    
    private void toModeIo(int mode){
        if(mode == TFTP.NETASCII_MODE){
            if(this.reader != null  && !(this.reader instanceof ToNetASCIIInputStream)){
                this.reader = new ToNetASCIIInputStream(this.reader);
            }
            if(this.writer != null  && !(this.writer instanceof ToNetASCIIOutputStream)){
                this.writer = new ToNetASCIIOutputStream(this.writer);
            }
        }
    }
    
    private void send(TFTPPacket packet){
        try{
            this.sender.accept(packet);
        }catch(Exception e){
            LOGGER.error("TFTP send packet err: ", e);
            this.end();
        }
    }
    
    private TFTPErrorPacket error(InetAddress address, int port, int type, String msg){
        TFTPErrorPacket error = new TFTPErrorPacket(address, port, type, msg);
        LOGGER.error("TFTP session error: "+ msg);
        return error;
    }
    
    public synchronized boolean process(TFTPPacket pp) throws TFTPPacketException{
        if(this.sender == null){
            return false;
        }
        if ((pp.getType() == TFTPPacket.READ_REQUEST || TFTPPacket.ACKNOWLEDGEMENT == pp.getType())
                && ServerMode.PUT_ONLY.equals(this.mode)) {
            this.send(this.error(pp.getAddress(), pp.getPort(), TFTPErrorPacket.ILLEGAL_OPERATION,
                "read mode not supported. "));
            this.end();
            return false;
        } else if ((pp.getType() == TFTPPacket.WRITE_REQUEST || TFTPPacket.DATA == pp.getType())
                && ServerMode.GET_ONLY.equals(this.mode)) {
            this.send(this.error(pp.getAddress(), pp.getPort(), TFTPErrorPacket.ILLEGAL_OPERATION,
                "write mode not supported. "));
            this.end();
            return false;
        }
        
        switch(pp.getType()){
            case TFTPPacket.READ_REQUEST://请求读取数据
                {
                    try{
                        TFTPReadRequestPacket p = (TFTPReadRequestPacket) pp;
                        byte[] data = new byte[buf_size];
                        String fileName = p.getFilename();
                        if(this.reader == null && this.readerFn != null){//开启一个新的输入流
                            this.reader = this.readerFn.apply(fileName);
                            if(this.reader != null){
                                this.maxFileLength = this.reader.available();
                            }
                        }
                        if(this.reader == null){
                    TFTPErrorPacket error = new TFTPErrorPacket(pp.getAddress(), pp.getPort(),
                        TFTPErrorPacket.FILE_NOT_FOUND, "file " + fileName + " not found. ");
                            this.sender.accept(error);
                            LOGGER.error("TFTP session reader not found !");
                            this.end();
                            return false;
                        }
                        toModeIo(p.getMode());//转换输入流
                        int size = 0;
                        try {
                            this.reader.reset();
                            this.reader.skip(0);
                            size = this.reader.read(data, 0, this.buf_size);
                        } catch (IOException e) {
                    TFTPErrorPacket error = new TFTPErrorPacket(pp.getAddress(), pp.getPort(),
                        TFTPErrorPacket.UNDEFINED, e.getMessage());
                            this.sender.accept(error);
                            this.end();
                            LOGGER.error("TFTP session file read error: ", e);
                            return false;
                        }
                        if(size != this.buf_size){
                            if(size <= 0){
                                data = new byte[0];
                            }else{
                                byte[] buf = new byte[size];
                                System.arraycopy(data, 0, buf, 0, size);
                                data = buf;
                            }
                        }
                        this.block = 1;
                        TFTPDataPacket resp = new TFTPDataPacket(pp.getAddress(), pp.getPort(), 1, data);
                        timer.schedule(new Task(resp).counter(1).sender(this::send).times(this.times) , 0);
                    }catch (Exception e){
                TFTPErrorPacket error = new TFTPErrorPacket(pp.getAddress(), pp.getPort(), TFTPErrorPacket.UNDEFINED,
                    e.getMessage());
                        this.send(error);
                        LOGGER.error("TFTP session file read error: ", e);
                        this.end();
                        return false;
                    }
                }
                break;
        case TFTPPacket.ACKNOWLEDGEMENT: {
                    TFTPAckPacket p = (TFTPAckPacket)pp;
                    //判断文件是否已经发送完成
                    try {
                        if((p.getBlockNumber()-1)*this.buf_size >= this.maxFileLength){
                    TFTPPacket resp = new TFTPDataPacket(p.getAddress(), p.getPort(), p.getBlockNumber() + 1,
                        new byte[0]);
                            timer.schedule(new Task(resp).counter(1).sender(this::send).times(this.times), 0);
                            this.end();
                            return true;
                        }
                        
                        byte[] data = new byte[this.buf_size];
                        this.reader.reset();
                        this.reader.skip((p.getBlockNumber()-1)*this.buf_size);
                        int size = this.reader.read(data, 0, this.buf_size);
                LOGGER.debug(
                    "read file from " + ((p.getBlockNumber() - 1) * this.buf_size) + ", length " + size + "bytes");
                        if(size < this.buf_size){
                            if(size <= -1){
                                data = new byte[0];
                            }else{
                                this.end();
                                byte[] buf = new byte[size];
                                System.arraycopy(data, 0, buf, 0, size);
                                data = buf;
                            }
                        }
                        TFTPPacket resp = new TFTPDataPacket(p.getAddress(), p.getPort(), p.getBlockNumber()+1, data);
                        timer.schedule(new Task(resp).counter(1).sender(this::send).times(this.times), 0);
                    } catch (Exception e) {
                        this.end();
                        LOGGER.error("read file error ", e);
                TFTPErrorPacket error = new TFTPErrorPacket(pp.getAddress(), pp.getPort(), TFTPErrorPacket.ERROR,
                    " read file error");
                        this.send(error);
                        return false;
                    }
                }
                break;
            case TFTPPacket.WRITE_REQUEST://请求写文件,此时数据尚未到达
                {
                    try{
                        TFTPWriteRequestPacket p = (TFTPWriteRequestPacket)pp;
                        if(this.writer == null && this.writerFn != null){
                            this.writer = this.writerFn.apply(p.getFilename());
                        }
                        if(this.writer == null){
                            LOGGER.error("TFTP server writer not found ");
                    TFTPErrorPacket error = new TFTPErrorPacket(pp.getAddress(), pp.getPort(),
                        TFTPErrorPacket.FILE_NOT_FOUND, " write file error");
                            this.send(error);
                            this.end();
                            return false;
                        }
                        toModeIo(p.getMode());
                LOGGER.debug(
                    "receive file " + p.getFilename() + " from " + p.getAddress().getHostAddress() + ":" + p.getPort());
                        TFTPPacket resp = new TFTPAckPacket(p.getAddress(), p.getPort(), this.block);
                        timer.schedule(new Task(resp).counter(1).sender(this::send).times(this.times) , 0);
                    }catch(Exception e){
                        LOGGER.error("write file error ", e);
                TFTPErrorPacket error = new TFTPErrorPacket(pp.getAddress(), pp.getPort(), TFTPErrorPacket.ERROR,
                    " read file error");
                        this.send(error);
                        this.end();
                        return false;
                    }
                }
                break;
            case TFTPPacket.DATA://开始写数据
                {
                    try{
                        TFTPDataPacket p = (TFTPDataPacket) pp;
                        byte[] data = new byte[p.getDataLength()];
                        System.arraycopy(p.getData(), p.getDataOffset(), data, 0, data.length);
                        if(p.getBlockNumber() != this.block + 1){
                            this.end();
                            return false;
                        }
                        this.block++;
                        if(data != null){
                            this.writer.write(data);
                        }
                        LOGGER.debug("received data : "+ data.length +" bytes");
                        TFTPPacket resp = new TFTPAckPacket(p.getAddress(), p.getPort(), this.block);
                        timer.schedule(new Task(resp).counter(1).sender(sender).times(this.times), 0);
                        if(data == null || data.length < this.buf_size){
                            this.end();
                        }
                    }catch(Exception e){
                        LOGGER.error("read file error ", e);
                TFTPErrorPacket error = new TFTPErrorPacket(pp.getAddress(), pp.getPort(), TFTPErrorPacket.ERROR,
                    " write file error : " + e.getMessage());
                        this.send(error);
                        this.end();
                        return false;
                    }
                }
                break;
        case TFTPPacket.ERROR: {
                    this.end();
                    return false;
                }
        }
        return true;
    }
    
    private class Task extends TimerTask{
        private int counter = 1;

        private Consumer<TFTPPacket> sender;

        private TFTPPacket packet;

        private int times = 3;
        
        public Task(TFTPPacket packet){
            this.packet = packet;
        }

        public Task counter(int counter){
            this.counter = counter;
            return this;
        }
        
        public Task times(int times){
            this.times = times;
            return this;
        }

        public Task sender(Consumer<TFTPPacket> sender){
            this.sender = sender;
            return this;
        }
        
        @Override
        public void run() {
            this.sender.accept(packet);
            if(this.counter < this.times){
                this.counter(this.counter + 1);
                timer.schedule(new Task(this.packet).counter(this.counter).times(this.times).sender(this.sender),
                    (int) Math.pow(2, times) * 3000);
            }
        }
    }
}
